{% comment %}
// vim: set syntax=asm :

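/* mmm 48 x 4: f32 accumulator tile. Each zmm register holds 16 consecutive rows
   of one tile column; each column of registers below is one tile column: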

    zmm0 zmm3 zmm6 zmm9
    zmm1 zmm4 zmm7 zmm10
    zmm2 zmm5 zmm8 zmm11

System V ABI:
    args: rdi, rsi, rdx, rcx, r8, r9
    preserve: rbx, rsp, rbp, r12, r13, r14, r15
    scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11
    return: rax (+rdx)

Windows ABI:
    args: RCX, RDX, R8, R9
    preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15
    scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15
    return: rax (+rdx)
*/
{% endcomment %}

{% include "preamble.tmpliq" size:"48x4", suffix:suffix, G:G, arch:"avx512" %}

// clear: reset the accumulator tile to zero, then jump back to non_linear_loop
// (defined outside this file section, presumably in the included preamble).
{{L}}clear:
    vzeroall                            // zeroes zmm0-zmm15, which covers the zmm0-zmm11 accumulators
    jmp     {{L}}non_linear_loop

// add_mat_mul: accumulate a product of the A and B operands into the tile.
//   [rdi + 8]  : k, the number of iterations of the inner loop (kept in rbx)
//   [rdi + 16] : A pointer (kept in rax)
//   [rdi + 24] : B pointer (kept in rcx)
// The loop body comes from an included template; how it consumes rax and rcx
// is not visible here.
{{L}}add_mat_mul:
    mov     rcx,    [rdi + 24]   // B
    mov     rax,    [rdi + 16]   // A

    mov     rbx,    [rdi + 8]    // k

    // Guard on k, the loop counter. With k == 0 the dec/jnz pair below would
    // wrap around and keep looping instead of leaving the tile untouched.
    test    rbx,    rbx
    jz      {{L}}non_linear_loop

{{L}}main_loop_packed_packed:
    {% include "3x4/packed_packed_loop1/avx-512.tmpli" %}

    dec             rbx                             // one k step done
    jnz             {{L}}main_loop_packed_packed

    jmp             {{L}}non_linear_loop

{% include "f32_scalars.tmpliq" from:0, to:11 %}
{% include "f32_per_rows.tmpliq" mr:48, from:0, to:11 %}
{% include "f32_per_cols.tmpliq" mr:48, from:0, to:11 %}

// add_unicast: add a strided 48x4 f32 matrix read from memory to the tile.
//   [rdi + 8]  : pointer to the first element
//   [rdi + 16] : row stride, in bytes (used as unscaled gather offsets)
//   [rdi + 24] : col stride, in bytes (added to the base pointer)
// Row offsets are computed in 32-bit arithmetic (eax / esi), so only the low
// 32 bits of the row stride take part in the addressing.
// Scratch: rax, rbx, rsi, r10, k1, zmm12-zmm15.
{{L}}add_unicast:

    mov     r10,    [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     rbx,    [rdi + 24]          // col stride

    xor     eax,    eax                 // running byte offset of the current row

    // Build the byte offsets of rows 0..15, four dwords per xmm register.
    // VEX-encoded vpinsrd is used rather than legacy pinsrd to avoid mixing
    // SSE and AVX encodings; the upper lanes it zeroes are overwritten below.
{% for i in (0..3) %}
    vpinsrd xmm14, xmm14, eax, {{i}}
    add     eax,    esi
{% endfor %}
{% for i in (0..3) %}
    vpinsrd xmm15, xmm15, eax, {{i}}
    add     eax,    esi
{% endfor %}
{% for i in (0..3) %}
    vpinsrd xmm12, xmm12, eax, {{i}}
    add     eax,    esi
{% endfor %}
{% for i in (0..3) %}
    vpinsrd xmm13, xmm13, eax, {{i}}
    add     eax,    esi
{% endfor %}

    // Concatenate the four xmm registers into zmm14: the 16 row offsets.
    vperm2f128      ymm14,  ymm14, ymm15,         32 // ymm14 <- xmm14::xmm15
    vperm2f128      ymm13,  ymm12, ymm13,         32 // ymm13 <- xmm12::xmm13
    vinsertf32x8    zmm14, zmm14, ymm13, 1           // zmm14 <- ymm14::ymm13

    // Rows 0..15 of each of the four columns.
{% for i in (0..3) %}
    kxnorw k1,k1,k1                                  // all-ones mask, set again before every gather
    vgatherdps      zmm12{k1},  [ r10 + zmm14 ]
    add     r10, rbx                                 // next column
    vaddps          zmm{{i | times: 3}},   zmm{{i | times: 3}},   zmm12
{% endfor %}

    imul    esi,    16                               // byte distance between two groups of 16 rows
    vpbroadcastd    zmm15, esi

    // Rows 16..31 (j = 1) then rows 32..47 (j = 2).
{% for j in (1..2) %}
    mov     r10,    [rdi + 8]                        // back to the first column
    vpaddd          zmm14, zmm14, zmm15              // shift every offset down by 16 rows

    {% for i in (0..3) %}
        kxnorw k1,k1,k1
        vgatherdps      zmm12{k1},  [ r10 + zmm14 ]
        add     r10, rbx
        vaddps          zmm{{i | times: 3 | plus: j}},   zmm{{i | times: 3 | plus: j}},   zmm12
    {% endfor %}
{% endfor %}

    jmp    {{L}}non_linear_loop

// add_row_col_products: rank-1 update of the tile,
// tile[row][col] += rows[row] * cols[col].
//   [rdi + 8]  : pointer to 48 contiguous f32, one per tile row
//   [rdi + 16] : pointer to 4 contiguous f32, one per tile column
// Scratch: rax, rbx, zmm12-zmm15.
{{L}}add_row_col_products:
    mov             rax, [ rdi + 8 ]                // per-row factors
    mov             rbx, [ rdi + 16 ]               // per-column factors

    // The 48 row factors fill three registers; zmm14 is left free for the
    // column broadcast below.
    vmovups         zmm12, zmmword ptr [rax]
    vmovups         zmm13, zmmword ptr [rax+64]
    vmovups         zmm15, zmmword ptr [rax+128]

{% for col in (0..3) %}
    {% assign acc = col | times: 3 %}
    vbroadcastss    zmm14, dword ptr [rbx + {{col | times: 4}} ]
    vfmadd231ps     zmm{{acc}}, zmm12, zmm14
    vfmadd231ps     zmm{{acc | plus: 1}}, zmm13, zmm14
    vfmadd231ps     zmm{{acc | plus: 2}}, zmm15, zmm14
{% endfor %}

    jmp    {{L}}non_linear_loop

// store: write the 48x4 f32 tile to memory, one scalar at a time.
//   [rdi + 8]  : destination pointer
//   [rdi + 16] : row stride, in bytes
//   [rdi + 24] : col stride, in bytes
// r8..r11 walk down columns 0..3; xmm12..xmm15 hold the four rows being
// written for columns 0..3.
{{L}}store:
    mov     r8,     [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     rbx,    [rdi + 24]          // col stride

    // top of each column: r8, r9, r10, r11 = c + {0, 1, 2, 3} * col stride
    lea     r9,     [ r8 + rbx ]
    lea     r10,    [ r8 + 2 * rbx ]
    lea     r11,    [ r9 + 2 * rbx ]

    {% for word in (0..2) %}
        {% for quarter in (0..3) %}
            {% for col in (0..3) %}
                {% assign acc = col | times: 3 | plus: word %}
                {% assign lane = col | plus: 12 %}
                vextractf32x4 xmm{{lane}}, zmm{{acc}}, {{quarter}}
            {% endfor %}
            {% for row in (0..3) %}
                {% for col in (0..3) %}
                    {% assign dst = col | plus: 8 %}
                    {% assign lane = col | plus: 12 %}
                    vextractps  dword ptr [r{{dst}}], xmm{{lane}}, {{row}}
                    add         r{{dst}}, rsi
                {% endfor %}
            {% endfor %}
        {% endfor %}
    {% endfor %}

    jmp     {{L}}non_linear_loop

{% include "postamble.tmpliq" size:"48x4", suffix:suffix, G:G, L:L, arch:"avx512" %}

